# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""[Preview] Live API client."""

import asyncio
import base64
import contextlib
import json
import logging
import typing
from typing import Any, AsyncIterator, Optional, Sequence, Union, get_args
import warnings

import google.auth
import pydantic
from websockets import ConnectionClosed

from . import _api_module
from . import _common
from . import _live_converters as live_converters
from . import _mcp_utils
from . import _transformers as t
from . import errors
from . import types
from ._api_client import BaseApiClient
from ._common import get_value_by_path as getv
from ._common import set_value_by_path as setv
from .live_music import AsyncLiveMusic
from .models import _Content_to_mldev


try:
  from websockets.asyncio.client import ClientConnection
  from websockets.asyncio.client import connect as ws_connect
except ModuleNotFoundError:
  # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
  from websockets.client import ClientConnection  # type: ignore
  from websockets.client import connect as ws_connect  # type: ignore

try:
  from google.auth.transport import requests
except ImportError:
  requests = None  # type: ignore[assignment]

if typing.TYPE_CHECKING:
  from mcp import ClientSession as McpClientSession
  from mcp.types import Tool as McpTool
  from ._adapters import McpToGenAiToolAdapter
  from ._mcp_utils import mcp_to_gemini_tool
else:
  McpClientSession: typing.Type = Any
  McpTool: typing.Type = Any
  McpToGenAiToolAdapter: typing.Type = Any
  try:
    from mcp import ClientSession as McpClientSession
    from mcp.types import Tool as McpTool
    from ._adapters import McpToGenAiToolAdapter
    from ._mcp_utils import mcp_to_gemini_tool
  except ImportError:
    McpClientSession = None
    McpTool = None
    McpToGenAiToolAdapter = None
    mcp_to_gemini_tool = None

# Logger for this module, named 'google_genai.live'.
logger = logging.getLogger('google_genai.live')

# Error message raised when a function response sent to the Gemini Developer
# API (i.e. not Vertex AI) lacks the `id` of the tool call it answers.
_FUNCTION_RESPONSE_REQUIRES_ID = (
    'FunctionResponse request must have an `id` field from the response of'
    ' a ToolCall.FunctionalCalls in Google AI.'
)


class AsyncSession:
  """[Preview] AsyncSession.

  A single live-API session bound to an open websocket connection. Use the
  `send_*` methods to send messages to the model and `receive` to read the
  server responses.
  """

  def __init__(
      self,
      api_client: BaseApiClient,
      websocket: ClientConnection,
      session_id: Optional[str] = None,
  ):
    """Initializes the session.

    Args:
      api_client: The API client. Its `vertexai` flag selects between the
        Vertex AI and Gemini Developer API message formats.
      websocket: The open websocket connection used to send and receive
        messages.
      session_id: Optional identifier of the session.
    """
    self._api_client = api_client
    self._ws = websocket
    self.session_id = session_id

  async def send(
      self,
      *,
      input: Optional[
          Union[
              types.ContentListUnion,
              types.ContentListUnionDict,
              types.LiveClientContentOrDict,
              types.LiveClientRealtimeInputOrDict,
              types.LiveClientToolResponseOrDict,
              types.FunctionResponseOrDict,
              Sequence[types.FunctionResponseOrDict],
          ]
      ] = None,
      end_of_turn: Optional[bool] = False,
  ) -> None:
    """[Deprecated] Send input to the model.

    > **Warning**: This method is deprecated and will be removed in a future
    version (not before Q3 2025). Please use one of the more specific methods:
    `send_client_content`, `send_realtime_input`, or `send_tool_response`
    instead.

    The method will send the input request to the server.

    Args:
      input: The input request to the model.
      end_of_turn: Whether the input is the last message in a turn.

    Raises:
      ValueError: If the input type is not supported.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)

      async with client.aio.live.connect(model='...') as session:
        await session.send(input='Hello world!', end_of_turn=True)
        async for message in session.receive():
          print(message)
    """
    warnings.warn(
        'The `session.send` method is deprecated and will be removed in a '
        'future version (not before Q3 2025).\n'
        'Please use one of the more specific methods: `send_client_content`, '
        '`send_realtime_input`, or `send_tool_response` instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    client_message = self._parse_client_message(input, end_of_turn)
    await self._ws.send(json.dumps(client_message))

  async def send_client_content(
      self,
      *,
      turns: Optional[
          Union[
              types.Content,
              types.ContentDict,
              list[Union[types.Content, types.ContentDict]],
          ]
      ] = None,
      turn_complete: bool = True,
  ) -> None:
    """Send non-realtime, turn based content to the model.

    There are two ways to send messages to the live API:
    `send_client_content` and `send_realtime_input`.

    `send_client_content` messages are added to the model context **in order**.
    Having a conversation using `send_client_content` messages is roughly
    equivalent to using the `Chat.send_message_stream` method, except that the
    state of the `chat` history is stored on the API server.

    Because of `send_client_content`'s order guarantee, the model cannot
    respond as quickly to `send_client_content` messages as to
    `send_realtime_input` messages. This makes the biggest difference when
    sending objects that have significant preprocessing time (typically images).

    The `send_client_content` message sends a list of `Content` objects,
    which has more options than the `media:Blob` sent by `send_realtime_input`.

    The main use-cases for `send_client_content` over `send_realtime_input` are:

    - Prefilling a conversation context (including sending anything that can't
      be represented as a realtime message), before starting a realtime
      conversation.
    - Conducting a non-realtime conversation, similar to `client.chat`, using
      the live api.

    Caution: Interleaving `send_client_content` and `send_realtime_input`
      in the same conversation is not recommended and can lead to unexpected
      results.

    Args:
      turns: A `Content` object or list of `Content` objects (or equivalent
        dicts).
      turn_complete: if true (the default) the model will reply immediately. If
        false, the model will wait for you to send additional client_content,
        and will not return until you send `turn_complete=True`.

    Example:

    .. code-block:: python

      from google import genai
      from google.genai import types
      import os

      if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
        MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
      else:
        MODEL_NAME = 'gemini-live-2.5-flash-preview'

      client = genai.Client()
      async with client.aio.live.connect(
          model=MODEL_NAME,
          config={"response_modalities": ["TEXT"]}
      ) as session:
        await session.send_client_content(
            turns=types.Content(
                role='user',
                parts=[types.Part(text="Hello world!")]))
        async for msg in session.receive():
          if msg.text:
            print(msg.text)
    """
    client_content = t.t_client_content(turns, turn_complete).model_dump(
        mode='json', exclude_none=True
    )

    if self._api_client.vertexai:
      client_content_dict = _common.convert_to_dict(
          client_content, convert_keys=True
      )
    else:
      client_content_dict = live_converters._LiveClientContent_to_mldev(
          from_object=client_content
      )

    await self._ws.send(json.dumps({'client_content': client_content_dict}))

  async def send_realtime_input(
      self,
      *,
      media: Optional[types.BlobImageUnionDict] = None,
      audio: Optional[types.BlobOrDict] = None,
      audio_stream_end: Optional[bool] = None,
      video: Optional[types.BlobImageUnionDict] = None,
      text: Optional[str] = None,
      activity_start: Optional[types.ActivityStartOrDict] = None,
      activity_end: Optional[types.ActivityEndOrDict] = None,
  ) -> None:
    """Send realtime input to the model, only send one argument per call.

    Use `send_realtime_input` for realtime audio chunks and video
    frames (images).

    With `send_realtime_input` the api will respond to audio automatically
    based on voice activity detection (VAD).

    `send_realtime_input` is optimized for responsiveness at the expense of
    deterministic ordering. Audio and video tokens are added to the
    context when they become available.

    Args:
      media: A `Blob`-like object or image, the realtime media to send.
      audio: A `Blob`-like audio chunk to send.
      audio_stream_end: The `audio_stream_end` flag to send.
      video: A `Blob`-like object or image to send as video.
      text: Text to send as realtime input.
      activity_start: An `ActivityStart` marker to send.
      activity_end: An `ActivityEnd` marker to send.

    Raises:
      ValueError: If not exactly one of the arguments is set.

    Example:

    .. code-block:: python

      from pathlib import Path

      from google import genai
      from google.genai import types

      import PIL.Image

      import os

      if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
        MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
      else:
        MODEL_NAME = 'gemini-live-2.5-flash-preview'


      client = genai.Client()

      async with client.aio.live.connect(
          model=MODEL_NAME,
          config={"response_modalities": ["TEXT"]},
      ) as session:
        await session.send_realtime_input(
            media=PIL.Image.open('image.jpg'))

        audio_bytes = Path('audio.pcm').read_bytes()
        await session.send_realtime_input(
            audio=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))

        async for msg in session.receive():
          if msg.text is not None:
            print(f'{msg.text}')
    """
    kwargs: _common.StringDict = {}
    if media is not None:
      kwargs['media'] = media
    if audio is not None:
      kwargs['audio'] = audio
    if audio_stream_end is not None:
      kwargs['audio_stream_end'] = audio_stream_end
    if video is not None:
      kwargs['video'] = video
    if text is not None:
      kwargs['text'] = text
    if activity_start is not None:
      kwargs['activity_start'] = activity_start
    if activity_end is not None:
      kwargs['activity_end'] = activity_end

    if len(kwargs) != 1:
      raise ValueError(
          f'Only one argument can be set, got {len(kwargs)}:'
          f' {list(kwargs.keys())}'
      )
    realtime_input = types.LiveSendRealtimeInputParameters.model_validate(
        kwargs
    )

    if self._api_client.vertexai:
      realtime_input_dict = (
          live_converters._LiveSendRealtimeInputParameters_to_vertex(
              from_object=realtime_input
          )
      )
    else:
      realtime_input_dict = (
          live_converters._LiveSendRealtimeInputParameters_to_mldev(
              from_object=realtime_input
          )
      )
    realtime_input_dict = _common.convert_to_dict(realtime_input_dict)
    realtime_input_dict = _common.encode_unserializable_types(
        realtime_input_dict
    )
    await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))

  async def send_tool_response(
      self,
      *,
      function_responses: Union[
          types.FunctionResponseOrDict,
          Sequence[types.FunctionResponseOrDict],
      ],
  ) -> None:
    """Send a tool response to the session.

    Use `send_tool_response` to reply to `LiveServerToolCall` messages
    from the server.

    To set the available tools, use the `config.tools` argument
    when you connect to the session (`client.live.connect`).

    Args:
      function_responses: A `FunctionResponse`-like object or list of
        `FunctionResponse`-like objects.

    Raises:
      ValueError: If, for the Gemini Developer API (not Vertex AI), a
        function response has no `id`.

    Example:

    .. code-block:: python

      from google import genai
      from google.genai import types

      import os

      if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
        MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
      else:
        MODEL_NAME = 'gemini-live-2.5-flash-preview'

      client = genai.Client()

      tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
      config = {
          "tools": tools,
          "response_modalities": ['TEXT']
      }

      async with client.aio.live.connect(
          model=MODEL_NAME,
          config=config
      ) as session:
        prompt = "Turn on the lights please"
        await session.send_client_content(
            turns={"parts": [{'text': prompt}]}
        )

        async for chunk in session.receive():
            if chunk.server_content:
              if chunk.text is not None:
                print(chunk.text)
            elif chunk.tool_call:
              print(chunk.tool_call)
              print('_'*80)
              function_response=types.FunctionResponse(
                      name='turn_on_the_lights',
                      response={'result': 'ok'},
                      id=chunk.tool_call.function_calls[0].id,
                  )
              print(function_response)
              await session.send_tool_response(
                  function_responses=function_response
              )

              print('_'*80)
    """
    tool_response = t.t_tool_response(function_responses)
    # Both backends use the same camelCase conversion.
    tool_response_dict = _common.convert_to_dict(
        tool_response, convert_keys=True
    )
    if not self._api_client.vertexai:
      # The Gemini Developer API needs the tool call `id` to match responses.
      for response in tool_response_dict.get('functionResponses', []):
        if response.get('id') is None:
          raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)

    await self._ws.send(json.dumps({'tool_response': tool_response_dict}))

  async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
    """Receive model responses from the server.

    The method yields the model responses from the server, stopping after
    the message whose `server_content.turn_complete` is set, so one
    iteration covers one complete model turn. When a returned message is a
    tool call, reply with `send_tool_response` to continue the turn.

    Yields:
      The model responses from the server.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)

      async with client.aio.live.connect(model='...') as session:
        await session.send_client_content(
            turns={'role': 'user', 'parts': [{'text': 'Hello world!'}]})
        async for message in session.receive():
          print(message)
    """
    # TODO(b/365983264) Handle intermittent issues for the user.
    while result := await self._receive():
      if result.server_content and result.server_content.turn_complete:
        yield result
        break
      yield result

  async def start_stream(
      self, *, stream: AsyncIterator[bytes], mime_type: str
  ) -> AsyncIterator[types.LiveServerMessage]:
    """[Deprecated] Start a live session from a data stream.

    > **Warning**: This method is deprecated and will be removed in a future
    version (not before Q3 2025). Please use the `receive` and
    `send_realtime_input` methods instead.

    The interaction terminates when the input stream is complete.
    This method starts a background task that sends the input stream to the
    model, while this generator receives and yields the responses from the
    model.

    Args:
      stream: An async iterator that yields the raw bytes to send.
      mime_type: The MIME type of the data in the stream.

    Yields:
      The server response messages received while the stream is sent.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)
      config = {'response_modalities': ['AUDIO']}
      async def audio_stream():
        stream = read_audio()
        for data in stream:
          yield data
      async with client.aio.live.connect(model='...', config=config) as session:
        async for audio in session.start_stream(stream=audio_stream(),
                                                mime_type='audio/pcm'):
          play_audio_chunk(audio.data)
    """
    warnings.warn(
        'Setting `AsyncSession.start_stream` is deprecated, '
        'and will be removed in a future release (not before Q3 2025). '
        'Please use the `receive`, and `send_realtime_input`, methods instead.',
        DeprecationWarning,
        stacklevel=4,
    )
    stop_event = asyncio.Event()
    # Start the send loop. When the stream is complete, stop_event is set.
    # Keep a reference so the task is not garbage collected mid-execution.
    send_task = asyncio.create_task(  # noqa: F841
        self._send_loop(stream, mime_type, stop_event)
    )
    # One waiter for stop_event, reused on every iteration. This way no
    # pending waiter task is left behind for each received message.
    stop_task = asyncio.create_task(stop_event.wait())
    recv_task = None
    try:
      while not stop_event.is_set():
        try:
          recv_task = asyncio.create_task(self._receive())
          await asyncio.wait(
              [recv_task, stop_task],
              return_when=asyncio.FIRST_COMPLETED,
          )
          if recv_task.done():
            yield recv_task.result()
            # Give a chance for the send loop to process requests.
            await asyncio.sleep(10**-12)
        except ConnectionClosed:
          break
    finally:
      # Cancel whichever helper tasks are still pending and wait for them.
      for pending in (recv_task, stop_task):
        if pending is not None and not pending.done():
          pending.cancel()
          try:
            await pending
          except asyncio.CancelledError:
            pass

  async def _receive(self) -> types.LiveServerMessage:
    """Receives and parses one message from the websocket.

    Returns:
      The parsed `LiveServerMessage`. An empty frame is parsed as `{}`.

    Raises:
      ValueError: If the received frame is not valid JSON.
    """
    parameter_model = types.LiveServerMessage()
    try:
      raw_response = await self._ws.recv(decode=False)
    except TypeError:
      # Fallback for `websockets` versions whose `recv` presumably has no
      # `decode` parameter.
      raw_response = await self._ws.recv()  # type: ignore[assignment]
    if raw_response:
      try:
        response = json.loads(raw_response)
      except json.decoder.JSONDecodeError as e:
        raise ValueError(f'Failed to parse response: {raw_response!r}') from e
    else:
      response = {}

    if self._api_client.vertexai:
      response_dict = live_converters._LiveServerMessage_from_vertex(response)
    else:
      response_dict = response

    return types.LiveServerMessage._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

  async def _send_loop(
      self,
      data_stream: AsyncIterator[bytes],
      mime_type: str,
      stop_event: asyncio.Event,
  ) -> None:
    """Sends each chunk of `data_stream` as realtime input, then sets `stop_event`.

    Args:
      data_stream: Async iterator of raw data chunks.
      mime_type: MIME type attached to every chunk.
      stop_event: Set once `data_stream` is exhausted.
    """
    async for data in data_stream:
      model_input = types.LiveClientRealtimeInput(
          media_chunks=[types.Blob(data=data, mime_type=mime_type)]
      )
      await self.send(input=model_input)
      # Give a chance for the receive loop to process responses.
      await asyncio.sleep(10**-12)
    # Give a chance for the receiver to process the last response.
    stop_event.set()

  def _parse_client_message(
      self,
      input: Optional[
          Union[
              types.ContentListUnion,
              types.ContentListUnionDict,
              types.LiveClientContentOrDict,
              types.LiveClientRealtimeInputOrDict,
              types.LiveClientToolResponseOrDict,
              types.FunctionResponseOrDict,
              Sequence[types.FunctionResponseOrDict],
          ]
      ] = None,
      end_of_turn: Optional[bool] = False,
  ) -> types.LiveClientMessageDict:
    """Converts the input of the deprecated `send` method to a client message.

    Args:
      input: The input passed to `send`. It can be:
        - a string or a list of parts;
        - a `Blob`-like object;
        - a `FunctionResponse`-like object, or a list of them;
        - a `LiveClientContent`, `LiveClientRealtimeInput` or
          `LiveClientToolResponse` object, or an equivalent dict.
      end_of_turn: Whether the input completes the turn. Only used for
        string/part input.

    Returns:
      A `LiveClientMessageDict` with one of `client_content`,
      `realtime_input` or `tool_response` set.

    Raises:
      ValueError: If the input type is not supported or fails validation.
        Also raised for the Gemini Developer API (not Vertex AI) when a
        function response has no `id`.
    """
    formatted_input: Any = input

    if not input:
      logger.info('No input provided. Assume it is the end of turn.')
      return {'client_content': {'turn_complete': True}}
    if isinstance(input, str):
      formatted_input = [input]
    elif isinstance(input, dict) and 'data' in input:
      try:
        blob_input = types.Blob(**input)
      except pydantic.ValidationError as e:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        ) from e
      if isinstance(blob_input, types.Blob) and isinstance(
          blob_input.data, bytes
      ):
        formatted_input = [
            blob_input.model_dump(mode='json', exclude_none=True)
        ]
    elif isinstance(input, types.Blob):
      formatted_input = [input]
    elif isinstance(input, dict) and 'name' in input and 'response' in input:
      # ToolResponse.FunctionResponse
      if not (self._api_client.vertexai) and 'id' not in input:
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      formatted_input = [input]

    if isinstance(formatted_input, Sequence) and any(
        isinstance(c, dict) and 'name' in c and 'response' in c
        for c in formatted_input
    ):
      # ToolResponse.FunctionResponse
      function_responses_input = []
      for item in formatted_input:
        if isinstance(item, dict):
          try:
            function_response_input = types.FunctionResponse(**item)
          except pydantic.ValidationError as e:
            raise ValueError(
                f'Unsupported input type "{type(input)}" or input content'
                f' "{input}"'
            ) from e
          if (
              function_response_input.id is None
              and not self._api_client.vertexai
          ):
            raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
          else:
            function_response_dict = function_response_input.model_dump(
                exclude_none=True, mode='json'
            )
            function_response_typeddict = types.FunctionResponseDict(
                name=function_response_dict.get('name'),
                response=function_response_dict.get('response'),
            )
            if function_response_dict.get('id'):
              function_response_typeddict['id'] = function_response_dict.get(
                  'id'
              )
            function_responses_input.append(function_response_typeddict)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=function_responses_input
          )
      )
    elif isinstance(formatted_input, Sequence) and any(
        isinstance(c, str) for c in formatted_input
    ):
      to_object: _common.StringDict = {}
      content_input_parts: list[types.PartUnion] = []
      for item in formatted_input:
        if isinstance(item, get_args(types.PartUnion)):
          content_input_parts.append(item)
      if self._api_client.vertexai:
        contents = [
            _common.convert_to_dict(item, convert_keys=True)
            for item in t.t_contents(content_input_parts)
        ]
      else:
        contents = [
            _Content_to_mldev(item, to_object)
            for item in t.t_contents(content_input_parts)
        ]

      content_dict_list: list[types.ContentDict] = []
      for item in contents:
        try:
          content_input = types.Content(**item)
        except pydantic.ValidationError as e:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          ) from e
        content_dict_list.append(
            types.ContentDict(
                parts=content_input.model_dump(exclude_none=True, mode='json')[
                    'parts'
                ],
                role=content_input.role,
            )
        )

      client_message = types.LiveClientMessageDict(
          client_content=types.LiveClientContentDict(
              turns=content_dict_list, turn_complete=end_of_turn
          )
      )
    elif (
        isinstance(formatted_input, Sequence)
        and formatted_input
        and all(
            isinstance(item, types.FunctionResponse)
            for item in formatted_input
        )
    ):
      # A list of FunctionResponse objects. This must be checked before the
      # generic Sequence (realtime media) branch below, which would otherwise
      # match first and reject the input.
      if not self._api_client.vertexai and any(
          not item.id for item in formatted_input
      ):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      function_response_list: list[types.FunctionResponseDict] = []
      for item in formatted_input:
        function_response_dict = item.model_dump(exclude_none=True, mode='json')
        function_response_typeddict = types.FunctionResponseDict(
            name=function_response_dict.get('name'),
            response=function_response_dict.get('response'),
        )
        if function_response_dict.get('id'):
          function_response_typeddict['id'] = function_response_dict.get('id')
        function_response_list.append(function_response_typeddict)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=function_response_list
          )
      )
    elif isinstance(formatted_input, Sequence):
      if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
        pass
      elif any(isinstance(b, types.Blob) for b in formatted_input):
        formatted_input = [
            b.model_dump(exclude_none=True, mode='json')
            for b in formatted_input
        ]
      else:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )

      client_message = types.LiveClientMessageDict(
          realtime_input=types.LiveClientRealtimeInputDict(
              media_chunks=formatted_input
          )
      )

    elif isinstance(formatted_input, dict):
      if 'content' in formatted_input or 'turns' in formatted_input:
        # TODO(b/365983264) Add validation checks for content_update input_dict.
        if 'turns' in formatted_input:
          content_turns = formatted_input['turns']
        else:
          content_turns = formatted_input['content']
        client_message = types.LiveClientMessageDict(
            client_content=types.LiveClientContentDict(
                turns=content_turns,
                turn_complete=formatted_input.get('turn_complete'),
            )
        )
      elif 'media_chunks' in formatted_input:
        try:
          realtime_input = types.LiveClientRealtimeInput(**formatted_input)
        except pydantic.ValidationError as e:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          ) from e
        client_message = types.LiveClientMessageDict(
            realtime_input=types.LiveClientRealtimeInputDict(
                media_chunks=realtime_input.model_dump(
                    exclude_none=True, mode='json'
                )['media_chunks']
            )
        )
      elif 'function_responses' in formatted_input:
        try:
          tool_response_input = types.LiveClientToolResponse(**formatted_input)
        except pydantic.ValidationError as e:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          ) from e
        client_message = types.LiveClientMessageDict(
            tool_response=types.LiveClientToolResponseDict(
                function_responses=tool_response_input.model_dump(
                    exclude_none=True, mode='json'
                )['function_responses']
            )
        )
      else:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )
    elif isinstance(formatted_input, types.LiveClientRealtimeInput):
      realtime_input_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      client_message = types.LiveClientMessageDict(
          realtime_input=types.LiveClientRealtimeInputDict(
              media_chunks=realtime_input_dict.get('media_chunks')
          )
      )
      # A truthiness check (not `is not None`) guards against an empty
      # media_chunks list, which would otherwise raise IndexError on [0].
      if (
          client_message['realtime_input'] is not None
          and client_message['realtime_input']['media_chunks']
          and isinstance(
              client_message['realtime_input']['media_chunks'][0].get('data'),
              bytes,
          )
      ):
        formatted_media_chunks: list[types.BlobDict] = []
        for item in client_message['realtime_input']['media_chunks']:
          if isinstance(item, dict):
            try:
              blob_input = types.Blob(**item)
            except pydantic.ValidationError as e:
              raise ValueError(
                  f'Unsupported input type "{type(input)}" or input content'
                  f' "{input}"'
              ) from e
            if (
                isinstance(blob_input, types.Blob)
                and isinstance(blob_input.data, bytes)
                and blob_input.data is not None
            ):
              formatted_media_chunks.append(
                  types.BlobDict(
                      data=base64.b64decode(blob_input.data),
                      mime_type=blob_input.mime_type,
                  )
              )

        client_message['realtime_input'][
            'media_chunks'
        ] = formatted_media_chunks

    elif isinstance(formatted_input, types.LiveClientContent):
      client_content_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      client_message = types.LiveClientMessageDict(
          client_content=types.LiveClientContentDict(
              turns=client_content_dict.get('turns'),
              turn_complete=client_content_dict.get('turn_complete'),
          )
      )
    elif isinstance(formatted_input, types.LiveClientToolResponse):
      # ToolResponse.FunctionResponse
      if (
          not (self._api_client.vertexai)
          and formatted_input.function_responses is not None
          and not (formatted_input.function_responses[0].id)
      ):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=formatted_input.model_dump(
                  exclude_none=True, mode='json'
              ).get('function_responses')
          )
      )
    elif isinstance(formatted_input, types.FunctionResponse):
      if not (self._api_client.vertexai) and not (formatted_input.id):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      function_response_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      function_response_typeddict = types.FunctionResponseDict(
          name=function_response_dict.get('name'),
          response=function_response_dict.get('response'),
      )
      if function_response_dict.get('id'):
        function_response_typeddict['id'] = function_response_dict.get('id')
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=[function_response_typeddict]
          )
      )

    else:
      raise ValueError(
          f'Unsupported input type "{type(input)}" or input content "{input}"'
      )

    return client_message

  async def close(self) -> None:
    """Closes the underlying websocket connection."""
    await self._ws.close()


class AsyncLive(_api_module.BaseModule):
  """[Preview] AsyncLive.

  Opens Live API websocket sessions via ``connect``; the music variant of the
  API is exposed through the ``music`` property.
  """

  def __init__(self, api_client: BaseApiClient):
    super().__init__(api_client)
    self._music = AsyncLiveMusic(api_client)

  @property
  def music(self) -> AsyncLiveMusic:
    """The ``AsyncLiveMusic`` module bound to the same API client."""
    return self._music

  def _copy_client_headers(self) -> dict[str, str]:
    """Returns a mutable copy of the client-level HTTP headers.

    A copy is taken so that headers added for a single connection (such as
    ``Authorization`` or the MCP usage header) are not written back into
    ``self._api_client._http_options``.

    Returns:
      A new dict; empty when the client has no headers configured.
    """
    original_headers = self._api_client._http_options.headers
    return original_headers.copy() if original_headers is not None else {}

  def _build_setup_request_dict(
      self,
      converter: typing.Callable[..., Any],
      transformed_model: str,
      parameter_model: types.LiveConnectConfig,
  ) -> dict[str, Any]:
    """Converts the connect parameters into the setup request dict.

    Args:
      converter: ``live_converters._LiveConnectParameters_to_mldev`` or
        ``live_converters._LiveConnectParameters_to_vertex``.
      transformed_model: The model name placed in ``LiveConnectParameters``.
      parameter_model: The normalized connect config.

    Returns:
      The converted dict with its top-level ``config`` entry removed.
    """
    request_dict = _common.convert_to_dict(
        converter(
            api_client=self._api_client,
            from_object=types.LiveConnectParameters(
                model=transformed_model,
                config=parameter_model,
            ).model_dump(exclude_none=True),
        )
    )
    del request_dict['config']
    return request_dict

  @contextlib.asynccontextmanager
  async def connect(
      self,
      *,
      model: str,
      config: Optional[types.LiveConnectConfigOrDict] = None,
  ) -> AsyncIterator[AsyncSession]:
    """[Preview] Connect to the live server.

    Note: the live API is currently in preview.

    Usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)
      config = {}
      async with client.aio.live.connect(model='...', config=config) as session:
        await session.send_client_content(
          turns=types.Content(
            role='user',
            parts=[types.Part(text='hello!')]
          ),
          turn_complete=True
        )
        async for message in session.receive():
          print(message)

    Args:
      model: The model to use for the live session.
      config: The configuration for the live session.

    Yields:
      An AsyncSession object.

    Raises:
      ValueError: If ``config`` sets request-level ``http_options``, if the
        ``requests`` transport needed to refresh google-auth credentials is
        not installed, or if the server's setup response is not valid JSON.
    """
    # TODO(b/404946570): Support per request http options.
    if isinstance(config, dict):
      config = types.LiveConnectConfig(**config)
    if config and config.http_options:
      raise ValueError(
          'google.genai.client.aio.live.connect() does not support'
          ' http_options at request-level in LiveConnectConfig yet. Please use'
          ' the client-level http_options configuration instead.'
      )

    base_url = self._api_client._websocket_base_url()
    if isinstance(base_url, bytes):
      base_url = base_url.decode('utf-8')
    transformed_model = t.t_model(self._api_client, model)  # type: ignore

    parameter_model = await _t_live_connect_config(self._api_client, config)

    if self._api_client.api_key and not self._api_client.vertexai:
      # Gemini Developer API: API key, or an ephemeral 'auth_tokens/' token.
      version = self._api_client._http_options.api_version
      api_key = self._api_client.api_key
      method = 'BidiGenerateContent'
      headers = self._copy_client_headers()
      if api_key.startswith('auth_tokens/'):
        warnings.warn(
            message=(
                "The SDK's ephemeral token support is experimental, and may"
                ' change in future versions.'
            ),
            category=errors.ExperimentalWarning,
        )
        method = 'BidiGenerateContentConstrained'
        headers['Authorization'] = f'Token {api_key}'
        if version != 'v1alpha':
          warnings.warn(
              message=(
                  "The SDK's ephemeral token support is in v1alpha only. "
                  'Please use client = genai.Client(api_key=token.name, '
                  'http_options=types.HttpOptions(api_version="v1alpha"))'
                  ' before session connection.'
              ),
              category=errors.ExperimentalWarning,
          )
      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}'

      request_dict = self._build_setup_request_dict(
          live_converters._LiveConnectParameters_to_mldev,
          transformed_model,
          parameter_model,
      )
      setv(request_dict, ['setup', 'model'], transformed_model)
    elif self._api_client.api_key and self._api_client.vertexai:
      # Vertex AI express mode.
      # Headers already contains api key for express mode.
      version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
      headers = self._copy_client_headers()

      request_dict = self._build_setup_request_dict(
          live_converters._LiveConnectParameters_to_vertex,
          transformed_model,
          parameter_model,
      )
      setv(request_dict, ['setup', 'model'], transformed_model)
    else:
      # Vertex AI without an API key: project/location credentials, or a
      # custom base URL (proxy) when project/location are not both set.
      version = self._api_client._http_options.api_version
      has_sufficient_auth = (
          self._api_client.project and self._api_client.location
      )
      if self._api_client.custom_base_url and not has_sufficient_auth:
        # API gateway proxy can use the auth in custom headers, not url.
        # Enable custom url if auth is not sufficient.
        uri = self._api_client.custom_base_url
        # Keep the model as is.
        transformed_model = model
        # Do not get credentials for custom url.
        headers = self._copy_client_headers()
      else:
        uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'

        if not self._api_client._credentials:
          # Get bearer token through Application Default Credentials.
          creds, _ = google.auth.default(  # type: ignore
              scopes=['https://www.googleapis.com/auth/cloud-platform']
          )
        else:
          creds = self._api_client._credentials
        # creds.valid may be False and creds.token None until a refresh
        # populates them.
        if not (creds.token and creds.valid):
          if requests is None:
            raise ValueError(
                'The requests module is required to refresh google-auth'
                ' credentials. Please install with'
                ' `pip install google-auth[requests]`'
            )
          auth_req = requests.Request()  # type: ignore
          creds.refresh(auth_req)
        bearer_token = creds.token

        headers = self._copy_client_headers()
        # An Authorization header configured at client level takes precedence.
        if not headers.get('Authorization'):
          headers['Authorization'] = f'Bearer {bearer_token}'

      location = self._api_client.location
      project = self._api_client.project
      if transformed_model.startswith('publishers/') and project and location:
        transformed_model = (
            f'projects/{project}/locations/{location}/' + transformed_model
        )
      request_dict = self._build_setup_request_dict(
          live_converters._LiveConnectParameters_to_vertex,
          transformed_model,
          parameter_model,
      )

      # Request AUDIO responses when no response modalities were configured.
      if (
          getv(
              request_dict, ['setup', 'generationConfig', 'responseModalities']
          )
          is None
      ):
        setv(
            request_dict,
            ['setup', 'generationConfig', 'responseModalities'],
            ['AUDIO'],
        )

    request = json.dumps(request_dict)

    if parameter_model.tools and _mcp_utils.has_mcp_tool_usage(
        parameter_model.tools
    ):
      _mcp_utils.set_mcp_usage_header(headers)

    async with ws_connect(
        uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx
    ) as ws:
      await ws.send(request)
      try:
        # websockets 14.0+
        raw_response = await ws.recv(decode=False)
      except TypeError:
        # Fallback for websockets versions whose recv() rejects `decode`.
        raw_response = await ws.recv()  # type: ignore[assignment]
      if raw_response:
        try:
          response = json.loads(raw_response)
        except json.decoder.JSONDecodeError as e:
          raise ValueError(
              f'Failed to parse response: {raw_response!r}'
          ) from e
      else:
        response = {}

      if self._api_client.vertexai:
        response_dict = live_converters._LiveServerMessage_from_vertex(response)
      else:
        response_dict = response

      setup_response = types.LiveServerMessage._from_response(
          response=response_dict, kwargs=parameter_model.model_dump()
      )
      if setup_response.setup_complete:
        session_id = setup_response.setup_complete.session_id
      else:
        session_id = None
      yield AsyncSession(
          api_client=self._api_client,
          websocket=ws,
          session_id=session_id,
      )


async def _t_live_connect_config(
    api_client: BaseApiClient,
    config: Optional[types.LiveConnectConfigOrDict],
) -> types.LiveConnectConfig:
  """Normalizes ``config`` into a new ``LiveConnectConfig`` for a session.

  The caller's ``config`` object is never mutated; the returned model is a
  copy. ``system_instruction`` is converted with ``t.t_content`` and
  ``tools`` is rebuilt so that MCP client sessions and MCP tools are replaced
  by GenAI tools.

  Args:
    api_client: The API client. Not read by this function.
    config: A ``LiveConnectConfig``, an equivalent dict, or ``None``.

  Returns:
    A new ``LiveConnectConfig`` with a normalized system instruction and tools.

  Warns:
    DeprecationWarning: If ``generation_config`` is set.
  """
  # Normalize every accepted input form into a LiveConnectConfig first, then
  # read system_instruction from the model. For dict input, this uses whatever
  # the model parsed from the dict's keys rather than a raw-key lookup that
  # could disagree with it.
  if config is None:
    base_model = types.LiveConnectConfig()
  elif isinstance(config, dict):
    base_model = types.LiveConnectConfig(**config)
  else:
    base_model = config

  # Build a copy (never assign onto the caller's object). Tools are cleared
  # here and rebuilt below, with MCP sessions/tools converted to GenAI tools.
  updates: dict[str, Any] = {'tools': None}
  if base_model.system_instruction is not None:
    updates['system_instruction'] = t.t_content(base_model.system_instruction)
  parameter_model_copy = base_model.model_copy(update=updates)

  if base_model.tools:
    parameter_model_copy.tools = []
    for tool in base_model.tools:
      if McpClientSession is not None and isinstance(tool, McpClientSession):
        mcp_to_genai_tool_adapter = McpToGenAiToolAdapter(
            tool, await tool.list_tools()
        )
        # Extend the config with the MCP session tools converted to GenAI tools.
        parameter_model_copy.tools.extend(mcp_to_genai_tool_adapter.tools)
      elif McpTool is not None and isinstance(tool, McpTool):
        parameter_model_copy.tools.append(mcp_to_gemini_tool(tool))
      else:
        parameter_model_copy.tools.append(tool)

  if parameter_model_copy.generation_config is not None:
    warnings.warn(
        'Setting `LiveConnectConfig.generation_config` is deprecated, '
        'please set the fields on `LiveConnectConfig` directly. This will '
        'become an error in a future version (not before Q3 2025)',
        DeprecationWarning,
        stacklevel=4,
    )

  return parameter_model_copy
